{ "cells": [ { "cell_type": "markdown", "id": "4249ef6e", "metadata": {}, "source": [ "# DL model with RSA\n", "This is a notebook used to investigate if deep learning model can decrypt message encrypted by RSA algorithm." ] }, { "cell_type": "code", "execution_count": 1, "id": "f9cb013c", "metadata": {}, "outputs": [], "source": [ "import math\n", "import pandas as pd\n", "import torch\n", "from torch import nn\n", "from d2l import torch as d2l\n", "import os\n", "import example #self-created package used to encrypt message" ] }, { "cell_type": "markdown", "id": "c4ff3a43", "metadata": {}, "source": [ "## Load Data and Preprocess" ] }, { "cell_type": "code", "execution_count": 2, "id": "f0d4e74b", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Go.\tVa !\n", "Hi.\tSalut !\n", "Run!\tCours !\n", "Run!\tCourez !\n", "Who?\tQui ?\n", "Wow!\tÇa alors !\n", "\n" ] } ], "source": [ "# use English-to-Freach translation dataset\n", "# we will only use the English part to encrypt nad decrypt\n", "\n", "d2l.DATA_HUB['fra-eng'] = (d2l.DATA_URL + 'fra-eng.zip',\n", " '94646ad1522d915e7b0f9296181140edcf86a4f5')\n", "\n", "#@save\n", "def read_data_nmt():\n", " \"\"\"Load the English-French dataset.\"\"\"\n", " data_dir = d2l.download_extract('fra-eng')\n", " with open(os.path.join(data_dir, 'fra.txt'), 'r') as f:\n", " return f.read()\n", "\n", "text = read_data_nmt()\n", "print(text[:75])" ] }, { "cell_type": "code", "execution_count": 3, "id": "538ff2b0", "metadata": {}, "outputs": [], "source": [ "#preprocess\n", "def preprocess(text):\n", " \"\"\"Preprocess the English-French dataset.\"\"\"\n", " def no_space(char, prev_char):\n", " return char in set(',.!?') and prev_char != ' '\n", "\n", " # Replace non-breaking space with space, and convert uppercase letters to\n", " # lowercase ones\n", " text = text.replace('\\u202f', ' ').replace('\\xa0', ' ').lower()\n", " # Insert space between words and punctuation marks\n", " out = [\n", " ' ' + char if i > 0 and no_space(char, text[i - 1]) else char\n", " for i, char in enumerate(text)]\n", " return ''.join(out)" ] }, { "cell_type": "code", "execution_count": 4, "id": "c26f0519", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "go .\tva !\n", "hi .\tsalut !\n", "run !\tcours !\n", "run !\tcourez !\n", "who ?\tqui ?\n", "wow !\tça alors !\n" ] } ], "source": [ "text = preprocess(text)\n", "print(text[:80])" ] }, { "cell_type": "code", "execution_count": 5, "id": "91d06535", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[['go', '.'],\n", " ['hi', '.'],\n", " ['run', '!'],\n", " ['run', '!'],\n", " ['who', '?'],\n", " ['wow', '!']]" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# tokenize rsa\n", "def tokenize(text, num_examples=None):\n", " \"\"\"\n", " Tokenize the English-French dataset.\n", " Only English is used for RSA\n", " \"\"\"\n", " target = []\n", " for i, line in enumerate(text.split('\\n')):\n", " if num_examples and i > num_examples:\n", " break\n", " parts = line.split('\\t')\n", " if len(parts) == 2:\n", " target.append(parts[0].split(' '))\n", " return target\n", "\n", "# English original sentence is our target, and\n", "# its encrypted message is our source\n", "target_rsa = tokenize(text)\n", "target_rsa[:6]" ] }, { "cell_type": "code", "execution_count": 6, "id": "7292dcdc", "metadata": {}, "outputs": [], "source": [ "#Truncate or pad sequences to ensure input has the same length/shape (num_steps)\n", "def truncate_pad(line, num_steps, padding_token):\n", " \"\"\"Truncate or pad sequences.\"\"\"\n", " if len(line) > num_steps:\n", " return line[:num_steps] # Truncate\n", " return line + [padding_token] * (num_steps - len(line)) # Pad" ] }, { "cell_type": "code", "execution_count": 7, "id": "ed0fe3b3", "metadata": {}, "outputs": [], "source": [ "# Create a RSA library used to do Encryption \n", "rsa_lib = example.RSA_lib()" ] }, { "cell_type": "code", "execution_count": 8, "id": "8141bfea", "metadata": {}, "outputs": [], "source": [ "def build_array_rsa(lines, vocab, num_steps):\n", " \"\"\"Transform text sequences of machine translation into minibatches.\"\"\"\n", " lines = [vocab[l] for l in lines] # string to its indices in vocabulary\n", " lines = [l + [vocab['']] for l in lines] # add final end-of-sentence symbol\n", " # truncate or pad to ensure the same shape \n", " array_tgt = torch.tensor([truncate_pad(l, num_steps, vocab['']) for l in lines])\n", " # find valid length\n", " valid_len = (array_tgt != vocab['']).type(torch.int32).sum(1)\n", " \n", " # compute soruce array (input X of Transfomer) by encrypting\n", " array_src_raw = torch.tensor([ rsa_lib.encode(l) for line in array_tgt for l in line]).reshape(array_tgt.shape)\n", " \n", " return array_src_raw, array_tgt, valid_len" ] }, { "cell_type": "code", "execution_count": 9, "id": "8db487bd", "metadata": {}, "outputs": [], "source": [ "# create rsa source vocabulary that is a one-to-one map to encrypted message\n", "# the reason is power over a number is very large here \n", "def rsa_src_vocab(array_raw):\n", " # convert tensor to a list\n", " a = []\n", " for lines in array_raw:\n", " ls = []\n", " for l in lines:\n", " ls.append(l.item())\n", " a.append(ls)\n", " # then create the Vocabulary \n", " rsa_src_vocab = d2l.Vocab(a)\n", " \n", " # finally convert the arrary \n", " array_src = torch.tensor([ rsa_src_vocab[l] for line in a for l in line]).reshape(array_raw.shape)\n", " return rsa_src_vocab, array_src" ] }, { "cell_type": "code", "execution_count": 10, "id": "0019cd52", "metadata": {}, "outputs": [], "source": [ "def load_data_rsa(batch_size, num_steps, num_examples=600):\n", " \"\"\"Return the iterator and the vocabularies of the translation dataset.\"\"\"\n", " text = preprocess(read_data_nmt())\n", " target = tokenize(text, num_examples)\n", " tgt_vocab = d2l.Vocab(target, min_freq=2, reserved_tokens=['', '', ''])\n", " \n", " src_array_raw, tgt_array, tgt_valid_len = build_array_rsa(target, tgt_vocab, num_steps)\n", " src_vocab, src_array = rsa_src_vocab(src_array_raw)\n", " \n", " # target valid length now is equal to source target length \n", " data_arrays = (src_array, tgt_valid_len, tgt_array, tgt_valid_len)\n", " data_iter = d2l.load_array(data_arrays, batch_size)\n", " return data_iter, src_vocab, tgt_vocab" ] }, { "cell_type": "code", "execution_count": 11, "id": "96e2dba9", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "X: tensor([[ 9, 28, 3, 2, 1, 1, 1, 1, 1, 1],\n", " [163, 34, 4, 2, 1, 1, 1, 1, 1, 1]], dtype=torch.int32)\n", "valid lengths for X: tensor([4, 4])\n", "Y: tensor([[ 9, 28, 4, 3, 1, 1, 1, 1, 1, 1],\n", " [163, 34, 5, 3, 1, 1, 1, 1, 1, 1]], dtype=torch.int32)\n", "valid lengths for Y: tensor([4, 4])\n" ] } ], "source": [ "train_iter, src_vocab, tgt_vocab = load_data_rsa(batch_size=2, num_steps=10)\n", "for X, X_valid_len, Y, Y_valid_len in train_iter:\n", " print('X:', X.type(torch.int32))\n", " print('valid lengths for X:', X_valid_len)\n", " print('Y:', Y.type(torch.int32))\n", " print('valid lengths for Y:', Y_valid_len)\n", " break" ] }, { "cell_type": "markdown", "id": "d363f15b", "metadata": {}, "source": [ "## Transformer\n", "We first give the network structure. " ] }, { "cell_type": "code", "execution_count": 12, "id": "7a12d955", "metadata": {}, "outputs": [], "source": [ "class PositionWiseFFN(nn.Module):\n", " \"\"\"Positionwise feed-forward network.\"\"\"\n", " def __init__(self, ffn_num_input, ffn_num_hiddens, ffn_num_outputs,\n", " **kwargs):\n", " super(PositionWiseFFN, self).__init__(**kwargs)\n", " self.dense1 = nn.Linear(ffn_num_input, ffn_num_hiddens)\n", " self.relu = nn.ReLU()\n", " self.dense2 = nn.Linear(ffn_num_hiddens, ffn_num_outputs)\n", "\n", " def forward(self, X):\n", " return self.dense2(self.relu(self.dense1(X)))" ] }, { "cell_type": "code", "execution_count": 13, "id": "f508779f", "metadata": {}, "outputs": [], "source": [ "class AddNorm(nn.Module):\n", " \"\"\"Residual connection followed by layer normalization.\"\"\"\n", " def __init__(self, normalized_shape, dropout, **kwargs):\n", " super(AddNorm, self).__init__(**kwargs)\n", " self.dropout = nn.Dropout(dropout)\n", " self.ln = nn.LayerNorm(normalized_shape)\n", "\n", " def forward(self, X, Y):\n", " return self.ln(self.dropout(Y) + X)" ] }, { "cell_type": "code", "execution_count": 14, "id": "9d01fc07", "metadata": {}, "outputs": [], "source": [ "class EncoderBlock(nn.Module):\n", " \"\"\"Transformer encoder block.\"\"\"\n", " def __init__(self, key_size, query_size, value_size, num_hiddens,\n", " norm_shape, ffn_num_input, ffn_num_hiddens, num_heads,\n", " dropout, use_bias=False, **kwargs):\n", " super(EncoderBlock, self).__init__(**kwargs)\n", " self.attention = d2l.MultiHeadAttention(key_size, query_size,\n", " value_size, num_hiddens,\n", " num_heads, dropout, use_bias)\n", " self.addnorm1 = AddNorm(norm_shape, dropout)\n", " self.ffn = PositionWiseFFN(ffn_num_input, ffn_num_hiddens,\n", " num_hiddens)\n", " self.addnorm2 = AddNorm(norm_shape, dropout)\n", "\n", " def forward(self, X, valid_lens):\n", " Y = self.addnorm1(X, self.attention(X, X, X, valid_lens))\n", " return self.addnorm2(Y, self.ffn(Y))" ] }, { "cell_type": "code", "execution_count": 15, "id": "e647167a", "metadata": {}, "outputs": [], "source": [ "class TransformerEncoder(d2l.Encoder):\n", " \"\"\"Transformer encoder.\"\"\"\n", " def __init__(self, vocab_size, key_size, query_size, value_size,\n", " num_hiddens, norm_shape, ffn_num_input, ffn_num_hiddens,\n", " num_heads, num_layers, dropout, use_bias=False, **kwargs):\n", " super(TransformerEncoder, self).__init__(**kwargs)\n", " self.num_hiddens = num_hiddens\n", " self.embedding = nn.Embedding(vocab_size, num_hiddens)\n", " self.pos_encoding = d2l.PositionalEncoding(num_hiddens, dropout)\n", " self.blks = nn.Sequential()\n", " for i in range(num_layers):\n", " self.blks.add_module(\n", " \"block\" + str(i),\n", " EncoderBlock(key_size, query_size, value_size, num_hiddens,\n", " norm_shape, ffn_num_input, ffn_num_hiddens,\n", " num_heads, dropout, use_bias))\n", "\n", " def forward(self, X, valid_lens, *args):\n", " # Since positional encoding values are between -1 and 1, the embedding\n", " # values are multiplied by the square root of the embedding dimension\n", " # to rescale before they are summed up\n", " X = self.pos_encoding(self.embedding(X) * math.sqrt(self.num_hiddens))\n", " self.attention_weights = [None] * len(self.blks)\n", " for i, blk in enumerate(self.blks):\n", " X = blk(X, valid_lens)\n", " self.attention_weights[\n", " i] = blk.attention.attention.attention_weights\n", " return X" ] }, { "cell_type": "code", "execution_count": 16, "id": "5dd12fc7", "metadata": {}, "outputs": [], "source": [ "class DecoderBlock(nn.Module):\n", " # The `i`-th block in the decoder\n", " def __init__(self, key_size, query_size, value_size, num_hiddens,\n", " norm_shape, ffn_num_input, ffn_num_hiddens, num_heads,\n", " dropout, i, **kwargs):\n", " super(DecoderBlock, self).__init__(**kwargs)\n", " self.i = i\n", " self.attention1 = d2l.MultiHeadAttention(key_size, query_size,\n", " value_size, num_hiddens,\n", " num_heads, dropout)\n", " self.addnorm1 = AddNorm(norm_shape, dropout)\n", " self.attention2 = d2l.MultiHeadAttention(key_size, query_size,\n", " value_size, num_hiddens,\n", " num_heads, dropout)\n", " self.addnorm2 = AddNorm(norm_shape, dropout)\n", " self.ffn = PositionWiseFFN(ffn_num_input, ffn_num_hiddens,\n", " num_hiddens)\n", " self.addnorm3 = AddNorm(norm_shape, dropout)\n", "\n", " def forward(self, X, state):\n", " enc_outputs, enc_valid_lens = state[0], state[1]\n", " # During training, all the tokens of any output sequence are processed\n", " # at the same time, so `state[2][self.i]` is `None` as initialized.\n", " # When decoding any output sequence token by token during prediction,\n", " # `state[2][self.i]` contains representations of the decoded output at\n", " # the `i`-th block up to the current time step\n", " if state[2][self.i] is None:\n", " key_values = X\n", " else:\n", " key_values = torch.cat((state[2][self.i], X), axis=1)\n", " state[2][self.i] = key_values\n", " if self.training:\n", " batch_size, num_steps, _ = X.shape\n", " # Shape of `dec_valid_lens`: (`batch_size`, `num_steps`), where\n", " # every row is [1, 2, ..., `num_steps`]\n", " dec_valid_lens = torch.arange(1, num_steps + 1,\n", " device=X.device).repeat(\n", " batch_size, 1)\n", " else:\n", " dec_valid_lens = None\n", "\n", " # Self-attention\n", " X2 = self.attention1(X, key_values, key_values, dec_valid_lens)\n", " Y = self.addnorm1(X, X2)\n", " # Encoder-decoder attention. Shape of `enc_outputs`:\n", " # (`batch_size`, `num_steps`, `num_hiddens`)\n", " Y2 = self.attention2(Y, enc_outputs, enc_outputs, enc_valid_lens)\n", " Z = self.addnorm2(Y, Y2)\n", " return self.addnorm3(Z, self.ffn(Z)), state" ] }, { "cell_type": "code", "execution_count": 17, "id": "2ee27184", "metadata": {}, "outputs": [], "source": [ "class TransformerDecoder(d2l.AttentionDecoder):\n", " def __init__(self, vocab_size, key_size, query_size, value_size,\n", " num_hiddens, norm_shape, ffn_num_input, ffn_num_hiddens,\n", " num_heads, num_layers, dropout, **kwargs):\n", " super(TransformerDecoder, self).__init__(**kwargs)\n", " self.num_hiddens = num_hiddens\n", " self.num_layers = num_layers\n", " self.embedding = nn.Embedding(vocab_size, num_hiddens)\n", " self.pos_encoding = d2l.PositionalEncoding(num_hiddens, dropout)\n", " self.blks = nn.Sequential()\n", " for i in range(num_layers):\n", " self.blks.add_module(\n", " \"block\" + str(i),\n", " DecoderBlock(key_size, query_size, value_size, num_hiddens,\n", " norm_shape, ffn_num_input, ffn_num_hiddens,\n", " num_heads, dropout, i))\n", " self.dense = nn.Linear(num_hiddens, vocab_size)\n", "\n", " def init_state(self, enc_outputs, enc_valid_lens, *args):\n", " return [enc_outputs, enc_valid_lens, [None] * self.num_layers]\n", "\n", " def forward(self, X, state):\n", " X = self.pos_encoding(self.embedding(X) * math.sqrt(self.num_hiddens))\n", " self._attention_weights = [[None] * len(self.blks) for _ in range(2)]\n", " for i, blk in enumerate(self.blks):\n", " X, state = blk(X, state)\n", " # Decoder self-attention weights\n", " self._attention_weights[0][\n", " i] = blk.attention1.attention.attention_weights\n", " # Encoder-decoder attention weights\n", " self._attention_weights[1][\n", " i] = blk.attention2.attention.attention_weights\n", " return self.dense(X), state\n", "\n", " @property\n", " def attention_weights(self):\n", " return self._attention_weights" ] }, { "cell_type": "markdown", "id": "ca94ea72", "metadata": {}, "source": [ "## Train and Predict" ] }, { "cell_type": "code", "execution_count": 18, "id": "1e522c8b", "metadata": { "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "loss 0.003, 6850.1 tokens/sec on cpu\n" ] }, { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", " \n", " 2021-09-27T22:27:59.884881\n", " image/svg+xml\n", " \n", " \n", " Matplotlib v3.4.3, https://matplotlib.org/\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n" ], "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "num_hiddens, num_layers, dropout, batch_size, num_steps = 32, 2, 0.1, 64, 10\n", "lr, num_epochs, device = 0.005, 200, d2l.try_gpu()\n", "ffn_num_input, ffn_num_hiddens, num_heads = 32, 64, 4\n", "key_size, query_size, value_size = 32, 32, 32\n", "norm_shape = [32]\n", "\n", "train_iter, src_vocab, tgt_vocab = load_data_rsa(batch_size, num_steps)\n", "\n", "encoder = TransformerEncoder(len(src_vocab), key_size, query_size, value_size,\n", " num_hiddens, norm_shape, ffn_num_input,\n", " ffn_num_hiddens, num_heads, num_layers, dropout)\n", "decoder = TransformerDecoder(len(tgt_vocab), key_size, query_size, value_size,\n", " num_hiddens, norm_shape, ffn_num_input,\n", " ffn_num_hiddens, num_heads, num_layers, dropout)\n", "net = d2l.EncoderDecoder(encoder, decoder)\n", "d2l.train_seq2seq(net, train_iter, lr, num_epochs, tgt_vocab, device)" ] }, { "cell_type": "markdown", "id": "7d9cba14", "metadata": {}, "source": [ "Transformer seems to be able to decrypt message. " ] }, { "cell_type": "code", "execution_count": 19, "id": "27bc4cbd", "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "def message_to_X(raw_sentence, raw_vocab, num_steps):\n", " \"\"\"convert message to indices of encrypted message in source vocabublary\"\"\"\n", " # print(raw_sentence)\n", " raw_inds = raw_vocab[raw_sentence.lower().split(' ')] + [raw_vocab['']]\n", " # print(raw_inds)\n", " valid_length = len(raw_inds)\n", " # print(valid_length)\n", " raw_inds = d2l.truncate_pad(raw_inds, num_steps, raw_vocab[''])\n", " encrypted_mess = [rsa_lib.encode(index) for index in raw_inds]\n", " return torch.tensor(src_vocab[encrypted_mess]), torch.tensor([valid_length])" ] }, { "cell_type": "code", "execution_count": 20, "id": "2fa1cb49", "metadata": {}, "outputs": [], "source": [ "def predict_seq2seq(net, src_tokens, enc_valid_len, tgt_vocab, num_steps,\n", " device, save_attention_weights=False):\n", " \"\"\"Predict for sequence to sequence.\"\"\"\n", " # evaluation mode\n", " net.eval()\n", " # Add the batch axis\n", " enc_X = torch.unsqueeze(\n", " torch.tensor(src_tokens, dtype=torch.long, device=device), dim=0)\n", " enc_outputs = net.encoder(enc_X, enc_valid_len)\n", " dec_state = net.decoder.init_state(enc_outputs, enc_valid_len)\n", " # Add the batch axis\n", " dec_X = torch.unsqueeze(torch.tensor(\n", " [tgt_vocab['']], dtype=torch.long, device=device), dim=0)\n", " output_seq, attention_weight_seq = [], []\n", " for _ in range(num_steps):\n", " Y, dec_state = net.decoder(dec_X, dec_state)\n", " # We use the token with the highest prediction likelihood as the input\n", " # of the decoder at the next time step\n", " dec_X = Y.argmax(dim=2)\n", " pred = dec_X.squeeze(dim=0).type(torch.int32).item()\n", " # Save attention weights (to be covered later)\n", " if save_attention_weights:\n", " attention_weight_seq.append(net.decoder.attention_weights)\n", " # Once the end-of-sequence token is predicted, the generation of the\n", " # output sequence is complete\n", " if pred == tgt_vocab['']:\n", " break\n", " output_seq.append(pred)\n", " return ' '.join(tgt_vocab.to_tokens(output_seq)), attention_weight_seq" ] }, { "cell_type": "code", "execution_count": 21, "id": "8e3364c1", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "go .\n", "i lost .\n", "he's calm .\n", "i'm home .\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/var/folders/jx/l2fp0rys0t30l4wlc1m6645r0000gn/T/ipykernel_28950/2525470819.py:8: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", " torch.tensor(src_tokens, dtype=torch.long, device=device), dim=0)\n" ] } ], "source": [ "engs = ['go .', \"i lost .\", 'he\\'s calm .', 'i\\'m home .'] # target\n", "for eng in engs:\n", " x, valid_len = message_to_X(eng, tgt_vocab, 10)\n", " translation, dec_attention_weight_seq = predict_seq2seq(\n", " net, x, valid_len, tgt_vocab, num_steps, device, True)\n", " print(translation)" ] }, { "cell_type": "markdown", "id": "390dfc94", "metadata": {}, "source": [ "Pefectly right!!!" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.11" } }, "nbformat": 4, "nbformat_minor": 5 }